# Copyright (c) 2017, Apple Inc. All rights reserved.
#
# Use of this source code is governed by a BSD-3-clause license that can be
# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause

import atexit as _atexit
import json
import os as _os
import shutil as _shutil
import tempfile as _tempfile
import warnings as _warnings
from copy import deepcopy as _deepcopy
from typing import Dict as _Dict
from typing import List as _List
from typing import Optional as _Optional

import numpy as _np
import numpy as _numpy

from coremltools import ComputeUnit as _ComputeUnit
from coremltools import ReshapeFrequency as _ReshapeFrequency
from coremltools import SpecializationStrategy as _SpecializationStrategy
from coremltools import _logger as logger
from coremltools import proto as _proto
from coremltools._deps import _HAS_TF_1, _HAS_TF_2, _HAS_TORCH
from coremltools.converters.mil.mil.program import Program as _Program
from coremltools.converters.mil.mil.scope import ScopeSource as _ScopeSource

from .utils import (
    _MLMODEL_EXTENSION,
    _MLPACKAGE_EXTENSION,
    _create_mlpackage,
    _get_model_spec_path,
    _has_custom_layer,
    _is_macos,
    _macos_version,
    _try_get_weights_dir_path,
)
from .utils import load_spec as _load_spec
from .utils import save_spec as _save_spec

if _HAS_TORCH:
    import torch as _torch

if _HAS_TF_1 or _HAS_TF_2:
    import tensorflow as _tf


try:
    from ..libmodelpackage import ModelPackage as _ModelPackage
except:
    _ModelPackage = None

try:
    from ..libcoremlpython import _MLModelProxy
except Exception as e:
    logger.warning(f"Failed to load _MLModelProxy: {e}")
    _MLModelProxy = None


try:
    from ..libcoremlpython import _MLModelAssetProxy
except Exception as e:
    logger.warning(f"Failed to load _MLModelAssetProxy: {e}")
    _MLModelAssetProxy = None

_HAS_PIL = True
try:
    from PIL import Image as _PIL_IMAGE
except:
    _HAS_PIL = False


# String identifiers for the supported model precision types.
_MLMODEL_FULL_PRECISION = "float32"
_MLMODEL_HALF_PRECISION = "float16"
_MLMODEL_QUANTIZED = "quantized_model"

# Every recognized precision identifier from above.
_VALID_MLMODEL_PRECISION_TYPES = [
    _MLMODEL_FULL_PRECISION,
    _MLMODEL_HALF_PRECISION,
    _MLMODEL_QUANTIZED,
]

# Quantization mode identifiers.
# Linear quantization
_QUANTIZATION_MODE_LINEAR_QUANTIZATION = "_linear_quantization"
# Linear quantization represented as a lookup table
_QUANTIZATION_MODE_LOOKUP_TABLE_LINEAR = "_lookup_table_quantization_linear"
# Lookup table quantization generated by K-Means
_QUANTIZATION_MODE_LOOKUP_TABLE_KMEANS = "_lookup_table_quantization_kmeans"
# Custom lookup table quantization
_QUANTIZATION_MODE_CUSTOM_LOOKUP_TABLE = "_lookup_table_quantization_custom"
# Dequantization
_QUANTIZATION_MODE_DEQUANTIZE = "_dequantize_network"  # used for testing
# Symmetric linear quantization
_QUANTIZATION_MODE_LINEAR_SYMMETRIC = "_linear_quantization_symmetric"

# All quantization modes listed above.
_SUPPORTED_QUANTIZATION_MODES = [
    _QUANTIZATION_MODE_LINEAR_QUANTIZATION,
    _QUANTIZATION_MODE_LOOKUP_TABLE_LINEAR,
    _QUANTIZATION_MODE_LOOKUP_TABLE_KMEANS,
    _QUANTIZATION_MODE_CUSTOM_LOOKUP_TABLE,
    _QUANTIZATION_MODE_DEQUANTIZE,
    _QUANTIZATION_MODE_LINEAR_SYMMETRIC,
]

# The subset of quantization modes that are represented as a lookup table.
_LUT_BASED_QUANTIZATION = [
    _QUANTIZATION_MODE_LOOKUP_TABLE_LINEAR,
    _QUANTIZATION_MODE_LOOKUP_TABLE_KMEANS,
    _QUANTIZATION_MODE_CUSTOM_LOOKUP_TABLE,
]

# Keys used in a model's user-defined metadata (see MLModel.user_defined_metadata).
_METADATA_CONVERSION_DATE = "com.github.apple.coremltools.conversion_date"
_METADATA_SOURCE = "com.github.apple.coremltools.source"
_METADATA_SOURCE_DIALECT = "com.github.apple.coremltools.source_dialect"
_METADATA_VERSION = "com.github.apple.coremltools.version"

from .compute_device import MLComputeDevice as _MLComputeDevice


def _verify_optimization_hint_input(optimization_hint_input: _Optional[dict] = None) -> None:
    """
    Throws an exception if ``optimization_hint_input`` is not valid.
    """
    if optimization_hint_input is None:
        return
    if not isinstance(optimization_hint_input, dict):
        raise TypeError('"optimization_hint_input" must be a dictionary or None')

    if optimization_hint_input != {} and _macos_version() < (15, 0):
        raise ValueError('Optimization hints are only available on macOS >= 15.0')

    for k in optimization_hint_input.keys():
        if k not in (
            "allowLowPrecisionAccumulationOnGPU",
            "reshapeFrequency",
            "specializationStrategy",
        ):
            raise ValueError(f"Unrecognized key in optimization_hint dictionary: {k}")

    if "allowLowPrecisionAccumulationOnGPU" in optimization_hint_input and not isinstance(
        optimization_hint_input["allowLowPrecisionAccumulationOnGPU"], bool
    ):
        raise TypeError(
            '"allowLowPrecisionAccumulationOnGPU" value of "optimization_hint_input" dictionary must be of type bool'
        )

    if "specializationStrategy" in optimization_hint_input and not isinstance(optimization_hint_input["specializationStrategy"], _SpecializationStrategy):
        raise TypeError('"specializationStrategy" value of "optimization_hint_input" dictionary must be of type coremltools.SpecializationStrategy')

    if "reshapeFrequency" in optimization_hint_input and not isinstance(optimization_hint_input["reshapeFrequency"], _ReshapeFrequency):
        raise TypeError('"reshapeFrequency" value of "optimization_hint_input" dictionary must be of type coremltools.ReshapeFrequency')



class _FeatureDescription:
    def __init__(self, fd_spec):
        self._fd_spec = fd_spec

    def __repr__(self):
        return "Features(%s)" % ",".join(map(lambda x: x.name, self._fd_spec))

    def __len__(self):
        return len(self._fd_spec)

    def __getitem__(self, key):
        for f in self._fd_spec:
            if key == f.name:
                return f.shortDescription
        raise KeyError("No feature with name %s." % key)

    def __contains__(self, key):
        for f in self._fd_spec:
            if key == f.name:
                return True
        return False

    def __setitem__(self, key, value):
        for f in self._fd_spec:
            if key == f.name:
                f.shortDescription = value
                return
        raise AttributeError("No feature with name %s." % key)

    def __iter__(self):
        for f in self._fd_spec:
            yield f.name


class MLState:
    """
    Holds state for an MLModel.

    Wraps a native state handle and forwards reads and writes of the model's
    state variables to it.

    See Also
    --------
    ct.MLModel.predict
    """

    def __init__(self, proxy) -> None:
        """
        Wrap ``proxy``, the native state handle, which is kept as ``__proxy__``.
        """
        self.__proxy__ = proxy

    def read_state(
        self,
        name: str,
    ) -> _np.ndarray:
        """
        Fetch the current value of one state variable.

        Parameters
        ----------
        name : str
            Name of the state variable.

        Returns
        -------
        numpy.ndarray
            The variable's value, as produced by the native handle.

        Raises
        ------
        RuntimeError
            If the state cannot be read (e.g. invalid state name).
        """
        handle = self.__proxy__
        value = handle.read_state(name)
        return value

    def write_state(
        self,
        name: str,
        value: _np.ndarray,
    ):
        """
        Overwrite the value of one state variable.

        Parameters
        ----------
        name : str
            Name of the state variable.

        value : numpy.ndarray
            New value for the variable.

        Returns
        -------
        Whatever the native handle's ``write_state`` returns.

        Raises
        ------
        RuntimeError
            If the state cannot be written (e.g. invalid value, or invalid state name).
        """
        handle = self.__proxy__
        result = handle.write_state(name, value)
        return result


class MLModelAsset:
    """
    A class representing a compiled model asset.

    It supports two initialization methods:
    - From a compiled model directory: The directory should have a '.mlmodelc' extension.
    - From memory: Allows direct initialization using in-memory model data.
    """

    def __init__(self, proxy) -> None:
        """
        Wrap an existing native asset handle.

        Raises
        ------
        TypeError
            If ``proxy`` is not a ``_MLModelAssetProxy``, or the native library
            failed to load so that no proxy type is available.
        """
        if _MLModelAssetProxy is None or not isinstance(proxy, _MLModelAssetProxy):
            raise TypeError("The proxy parameter must be of type _MLModelAssetProxy.")
        self.__proxy__ = proxy

    @classmethod
    def from_path(
        cls,
        compiled_model_path: str,
    ) -> "MLModelAsset":
        """
        Create an MLModelAsset instance from a compiled model path.

        Parameters
        ----------
        compiled_model_path : str
            The file path to the compiled model.

        Returns
        -------
        MLModelAsset
            An instance of MLModelAsset created from the specified path.
        """
        # NOTE(review): the native result is returned as-is rather than wrapped in
        # ``cls(...)``; confirm it is already an MLModelAsset as documented.
        # If the native library failed to load, ``_MLModelProxy`` is None and this
        # raises AttributeError.
        return _MLModelProxy.create_model_asset_from_path(compiled_model_path)

    @classmethod
    def from_memory(
        cls,
        spec_data: bytes,
        blob_mapping: _Optional[_Dict[str, bytes]] = None,
    ) -> "MLModelAsset":
        """
        Create an MLModelAsset instance from in-memory data.

        Parameters
        ----------
        spec_data : bytes
            The specification data of the model.

        blob_mapping : Dict[str, bytes] or None
            A dictionary with blob path as the key and blob data as the value.
            ``None`` (the default) is treated as an empty mapping.

        Returns
        -------
        MLModelAsset
            An instance of MLModelAsset created from the provided memory data.
        """
        # Avoid a shared mutable default: build a fresh mapping per call.
        if blob_mapping is None:
            blob_mapping = {}
        # NOTE(review): as in ``from_path``, the native result is returned unwrapped.
        return _MLModelProxy.create_model_asset_from_memory(spec_data, blob_mapping)

class MLModel:
    """
    This class defines the minimal interface to a Core ML object in Python.

    At a high level, the protobuf specification consists of:

    - Model description: Encodes names and type information of the inputs and outputs to the model.
    - Model parameters: The set of parameters required to represent a specific instance of the model.
    - Metadata: Information about the origin, license, and author of the model.

    With this class, you can inspect a Core ML model, modify metadata, and make
    predictions for the purposes of testing (on select platforms).

    Examples
    --------
    .. sourcecode:: python

        # Load the model
        model = MLModel("HousePricer.mlmodel")

        # Set the model metadata
        model.author = "Author"
        model.license = "BSD"
        model.short_description = "Predicts the price of a house in the Seattle area."

        # Get the interface to the model
        model.input_description
        model.output_description

        # Set feature descriptions manually
        model.input_description["bedroom"] = "Number of bedrooms"
        model.input_description["bathrooms"] = "Number of bathrooms"
        model.input_description["size"] = "Size (in square feet)"

        # Set output feature descriptions
        model.output_description["price"] = "Price of the house"

        # Make predictions
        predictions = model.predict({"bedroom": 1.0, "bathrooms": 1.0, "size": 1240})

        # Get the spec of the model
        spec = model.get_spec()

        # Save the model
        model.save("HousePricer.mlpackage")

        # Load the model from the spec object
        spec = model.get_spec()
        # modify spec (e.g. rename inputs/outputs etc)
        model = MLModel(spec)
        # if model type is mlprogram, i.e. spec.WhichOneof('Type') == "mlProgram", then:
        model = MLModel(spec, weights_dir=model.weights_dir)

        # Load a non-default function from a multifunction .mlpackage
        model = MLModel("MultifunctionModel.mlpackage", function_name="deep_features")

    See Also
    --------
    predict
    """

    def __init__(
        self,
        model,
        is_temp_package=False,
        mil_program=None,
        skip_model_load=False,
        compute_units=_ComputeUnit.ALL,
        weights_dir=None,
        function_name=None,
        optimization_hints: _Optional[dict] = None,
    ) -> None:
        """
        Construct an MLModel from an ``.mlmodel`` file, an ``.mlpackage`` directory,
        or a ``Model_pb2`` spec object.

        Parameters
        ----------
        model: str or Model_pb2

            For an ML program (``mlprogram``), the model can be a path string (``.mlpackage``) or ``Model_pb2``.
            If it is a path string, it must point to a directory containing bundle
            artifacts (such as ``weights.bin``).
            If it is of type ``Model_pb2`` (spec), then you must also provide ``weights_dir`` if the model
            has weights, because both the proto spec and the weights are
            required to initialize and load the model.
            The proto spec for an ``mlprogram``, unlike a neural network (``neuralnetwork``),
            does not contain the weights; they are stored separately.
            If the model does not have weights, you can provide an empty ``weights_dir``.

            For non- ``mlprogram`` model types, the model can be a path string (``.mlmodel``)
            or type ``Model_pb2``, such as a spec object.

        is_temp_package: bool
            Set to ``True`` if the input model package dir is temporary and can be deleted upon interpreter termination.

        mil_program: coremltools.converters.mil.Program
            Set to the MIL program object, if available.
            It is available whenever an MLModel object is constructed using
            the unified converter API `coremltools.convert() <https://apple.github.io/coremltools/source/coremltools.converters.convert.html>`_.

        skip_model_load: bool
            Set to ``True`` to prevent Core ML Tools from calling into the Core ML framework
            to compile and load the model. In that case, the returned model object cannot
            be used to make a prediction. This flag may be used to load a newer model
            type on an older Mac, to inspect or load/save the spec.

            Example: Loading an ML program model type on a macOS 11, since an ML program can be
            compiled and loaded only from macOS12+.

            Defaults to ``False``.

        compute_units: coremltools.ComputeUnit
            The set of processing units the model can use to make predictions.

            An enum with four possible values:
                - ``coremltools.ComputeUnit.ALL``: Use all compute units available, including the
                  neural engine.
                - ``coremltools.ComputeUnit.CPU_ONLY``: Limit the model to only use the CPU.
                - ``coremltools.ComputeUnit.CPU_AND_GPU``: Use both the CPU and GPU,
                  but not the neural engine.
                - ``coremltools.ComputeUnit.CPU_AND_NE``: Use both the CPU and neural engine, but
                  not the GPU. Available only for macOS >= 13.0.

        weights_dir: str
            Path to the weight directory, required when loading an MLModel of type ``mlprogram``,
            from a spec object, such as when the argument ``model`` is of type ``Model_pb2``.

        function_name : str
            The name of the function from ``model`` to load.
            If not provided, ``function_name`` will be set to the ``defaultFunctionName`` in the proto.

        optimization_hints : dict or None
            Keys are the names of the optimization hint: 'allowLowPrecisionAccumulationOnGPU', 'reshapeFrequency'
                or 'specializationStrategy'.

            - 'allowLowPrecisionAccumulationOnGPU' value must have ``bool`` type.
            - 'reshapeFrequency' value must have ``coremltools.ReshapeFrequency`` type.
            - 'specializationStrategy' must have ``coremltools.SpecializationStrategy`` type.

        Raises
        ------
        TypeError
            If ``compute_units`` is not a ``coremltools.ComputeUnit``, if ``model`` is
            neither a path string nor a ``Model_pb2`` object, or if an
            ``optimization_hints`` value has the wrong type.
        ValueError
            If ``CPU_AND_NE`` is requested on macOS < 13.0, ``mil_program`` has the
            wrong type, ``optimization_hints`` is otherwise invalid, or
            ``function_name`` does not name a function of the model.
        Exception
            If ``model`` is an ``mlProgram`` spec and ``weights_dir`` is not given.

        Notes
        -----
        Internally this maintains the following:

        - ``_MLModelProxy``: A pybind wrapper around
          CoreML::Python::Model (see
          `coremltools/coremlpython/CoreMLPython.mm <https://github.com/apple/coremltools/blob/main/coremlpython/CoreMLPython.mm>`_)

        - ``package_path`` (mlprogram only): Directory containing all artifacts (``.mlmodel``,
          weights, and so on).

        - ``weights_dir`` (mlprogram only): Directory containing weights inside the package_path.

        Temporary packages are removed at interpreter exit. The cleanup is
        registered as soon as the package is known, so it also runs if loading fails.

        Examples
        --------
        .. sourcecode:: python

            loaded_model = MLModel("my_model.mlmodel")
            loaded_model = MLModel("my_model.mlpackage")

        """

        def cleanup(package_path):
            # Remove a temporary package directory if it still exists.
            if _os.path.exists(package_path):
                _shutil.rmtree(package_path)

        def does_model_contain_mlprogram(model) -> bool:
            """
            Is this an mlprogram or is it a pipeline with at least one mlprogram?
            """
            model_type = model.WhichOneof("Type")

            if model_type == "mlProgram":
                return True
            elif model_type not in ("pipeline", "pipelineClassifier", "pipelineRegressor"):
                return False

            # Does this pipeline contain an mlprogram?
            if model_type == "pipeline":
                pipeline_models = model.pipeline.models
            elif model_type == "pipelineClassifier":
                pipeline_models = model.pipelineClassifier.pipeline.models
            else:
                assert model_type == "pipelineRegressor"
                pipeline_models = model.pipelineRegressor.pipeline.models

            # Recurse, since pipelines may contain nested pipelines.
            return any(does_model_contain_mlprogram(m) for m in pipeline_models)

        if not isinstance(compute_units, _ComputeUnit):
            raise TypeError('"compute_units" parameter must be of type: coremltools.ComputeUnit')
        elif (compute_units == _ComputeUnit.CPU_AND_NE
              and _is_macos()
              and _macos_version() < (13, 0)
        ):
            raise ValueError(
                'coremltools.ComputeUnit.CPU_AND_NE is only available on macOS >= 13.0'
            )

        _verify_optimization_hint_input(optimization_hints)

        self.compute_unit = compute_units
        self.function_name = function_name
        # Keep a private copy so later changes to the caller's dict have no effect.
        if optimization_hints is not None:
            self.optimization_hints = optimization_hints.copy()
        else:
            self.optimization_hints = None

        self.is_package = False
        self.is_temp_package = False
        self.package_path = None
        self._weights_dir = None
        if mil_program is not None and not isinstance(mil_program, _Program):
            raise ValueError('"mil_program" must be of type "coremltools.converters.mil.Program"')
        self._mil_program = mil_program

        if isinstance(model, str):
            model = _os.path.abspath(_os.path.expanduser(_os.path.expandvars(model)))
            if _os.path.isdir(model):
                self.is_package = True
                self.package_path = model
                self.is_temp_package = is_temp_package
                if self.is_temp_package:
                    # Register before loading so the package is removed at exit even
                    # if anything below raises.
                    _atexit.register(cleanup, self.package_path)
                self._weights_dir = _try_get_weights_dir_path(model)
            self.__proxy__, self._spec, self._framework_error = self._get_proxy_and_spec(
                model, compute_units, skip_model_load=skip_model_load, optimization_hints=optimization_hints,
            )
        elif isinstance(model, _proto.Model_pb2.Model):
            if does_model_contain_mlprogram(model):
                if model.WhichOneof("Type") == "mlProgram" and weights_dir is None:
                    raise Exception(
                        "MLModel of type mlProgram cannot be loaded just from the model spec object. "
                        "It also needs the path to the weights file. Please provide that as well, "
                        "using the 'weights_dir' argument."
                    )
                self.is_package = True
                self.is_temp_package = True
                filename = _create_mlpackage(model, weights_dir)
                self.package_path = filename
                # This package was created here and is temporary; schedule its removal
                # now so it is not leaked if loading below raises.
                _atexit.register(cleanup, self.package_path)
                self._weights_dir = _try_get_weights_dir_path(filename)
                self.__proxy__, self._spec, self._framework_error = self._get_proxy_and_spec(
                    filename, compute_units, skip_model_load=skip_model_load, optimization_hints=optimization_hints
                )
            else:
                # Non-mlprogram specs are loaded via a temporary .mlmodel file.
                with _tempfile.NamedTemporaryFile(suffix=_MLMODEL_EXTENSION, delete=False) as tmp:
                    filename = tmp.name
                try:
                    _save_spec(model, filename)
                    self.__proxy__, self._spec, self._framework_error = self._get_proxy_and_spec(
                        filename, compute_units, skip_model_load=skip_model_load, optimization_hints=optimization_hints
                    )
                finally:
                    # The file is only needed while loading; remove it even on failure.
                    try:
                        _os.remove(filename)
                    except OSError:
                        pass
        else:
            raise TypeError(
                "Expected model to be a .mlmodel file, .mlpackage file or a Model_pb2 object"
            )

        self._input_description = _FeatureDescription(self._spec.description.input)
        self._output_description = _FeatureDescription(self._spec.description.output)
        self._model_input_names_set = {i.name for i in self._spec.description.input}

        # If function_name is not passed, self.function_name defaults to defaultFunctionName in the proto.
        default_function_name = self._spec.description.defaultFunctionName
        if self.function_name is None and len(default_function_name) > 0:
            self.function_name = default_function_name

        if self.function_name is not None:
            if not self._is_multifunction() and self.function_name != "main":
                raise ValueError('function_name must be "main" for non multifunction model')

        # Updated self._model_input_names_set based on self.function_name.
        # self._model_input_names_set defines the allowed input keys for the data dictionary passed to self.predict().
        if self.function_name is not None and self._is_multifunction():
            f = self._get_function_description(self.function_name)
            self._model_input_names_set = {i.name for i in f.input}

    def _get_proxy_and_spec(
        self,
        filename: str,
        compute_units: _ComputeUnit,
        skip_model_load: _Optional[bool] = False,
        optimization_hints: _Optional[dict] = None,
    ):
        """
        Load the spec at ``filename`` and, when possible, a native model proxy.

        Returns
        -------
        tuple
            ``(proxy, spec, error)``. ``proxy`` is ``None`` when the native library
            is unavailable, when ``skip_model_load`` is set, when the spec version is
            newer than the framework supports, or when loading raised a
            ``RuntimeError``. In that last case ``error`` holds the exception and a
            ``RuntimeWarning`` is emitted; otherwise ``error`` is ``None``.
        """
        model_path = _os.path.expanduser(filename)
        specification = _load_spec(model_path)

        if not _MLModelProxy or skip_model_load:
            return None, specification, None

        # A spec newer than the installed framework can still be inspected, just not loaded.
        max_spec_version = _MLModelProxy.maximum_supported_specification_version()
        if specification.specificationVersion > max_spec_version:
            return None, specification, None

        selected_function = "" if self.function_name is None else self.function_name

        # The native proxy takes hint values as strings: bools via str(), enums by name.
        hints = {} if optimization_hints is None else optimization_hints
        hint_strings = {
            hint_name: (str(hint_value) if isinstance(hint_value, bool) else hint_value.name)
            for hint_name, hint_value in hints.items()
        }

        try:
            proxy = _MLModelProxy(
                model_path,
                compute_units.name,
                selected_function,
                hint_strings,
                None,
            )
        except RuntimeError as load_error:
            _warnings.warn(
                "You will not be able to run predict() on this Core ML model."
                " Underlying exception message was: "
                + str(load_error),
                RuntimeWarning,
            )
            return None, specification, load_error
        return proxy, specification, None


    @property
    def short_description(self):
        """Short free-text description stored in the spec's metadata."""
        return self._spec.description.metadata.shortDescription

    @short_description.setter
    def short_description(self, short_description):
        self._spec.description.metadata.shortDescription = short_description

    @property
    def input_description(self):
        """Name -> description view over the model's input features."""
        return self._input_description

    @property
    def output_description(self):
        """Name -> description view over the model's output features."""
        return self._output_description

    @property
    def user_defined_metadata(self):
        """The spec's user-defined metadata field itself (not a copy)."""
        return self._spec.description.metadata.userDefined

    @property
    def author(self):
        """Author string stored in the spec's metadata."""
        return self._spec.description.metadata.author

    @author.setter
    def author(self, author):
        self._spec.description.metadata.author = author

    @property
    def license(self):
        """License string stored in the spec's metadata."""
        return self._spec.description.metadata.license

    @license.setter
    def license(self, license):
        self._spec.description.metadata.license = license

    @property
    def version(self):
        """Version string stored in the spec's metadata."""
        return self._spec.description.metadata.versionString

    @version.setter
    def version(self, version_string):
        self._spec.description.metadata.versionString = version_string

    @property
    def weights_dir(self):
        """Weights directory inside the package, or ``None`` if not applicable."""
        return self._weights_dir

    def __repr__(self):
        return repr(self._spec.description)

    def __str__(self):
        return repr(self)

    def save(self, save_path: str):
        """
        Save the model to an ``.mlmodel`` format. For an MIL program, the ``save_path`` is
        a package directory containing the ``mlmodel`` and weights.

        Any existing file or directory at the final destination is replaced. For a
        package, a ``save_path`` without an extension gets ``.mlpackage`` appended.
        The extension is validated before anything on disk is removed.

        Parameters
        ----------
        save_path: str
            Target file path / bundle directory for the model.

        Raises
        ------
        Exception
            If the model is a package and ``save_path`` has an extension other than
            ``.mlpackage``.

        Examples
        --------
        .. sourcecode:: python

            model.save("my_model_file.mlmodel")
            loaded_model = MLModel("my_model_file.mlmodel")

        """
        save_path = _os.path.expanduser(save_path)

        if self.is_package:
            # Resolve and validate the final path first. A rejected path then never
            # deletes existing data, and the path cleaned up below is the one
            # copytree writes to.
            _, ext = _os.path.splitext(save_path)
            if not ext:
                save_path = "{}{}".format(save_path, _MLPACKAGE_EXTENSION)
            elif ext != _MLPACKAGE_EXTENSION:
                raise Exception(
                    "For an ML Program, extension must be {} (not {}). Please see https://coremltools.readme.io/docs/unified-conversion-api#target-conversion-formats to see the difference between neuralnetwork and mlprogram model types.".format(
                        _MLPACKAGE_EXTENSION, ext
                    )
                )

        # Clean up existing file or directory at the destination.
        if _os.path.isdir(save_path):
            _shutil.rmtree(save_path)
        elif _os.path.exists(save_path):
            _os.remove(save_path)

        if not self.is_package:
            _save_spec(self._spec, save_path)
            return

        _shutil.copytree(self.package_path, save_path)

        if self._mil_program is not None and all(
            _ScopeSource.EXIR_DEBUG_HANDLE in function._essential_scope_sources
            for function in self._mil_program.functions.values()
        ):
            debug_handle_to_ops_mapping = (
                self._mil_program.construct_debug_handle_to_ops_mapping()
            )
            if len(debug_handle_to_ops_mapping) > 0:
                debug_handle_to_ops_mapping_as_json = json.dumps(
                    {
                        # ``.get`` avoids a protobuf map lookup inserting an empty
                        # version entry into the spec, which is saved below.
                        "version": self.user_defined_metadata.get(_METADATA_VERSION, ""),
                        "mapping": debug_handle_to_ops_mapping,
                    }
                )
                saved_debug_handle_to_ops_mapping_path = _os.path.join(
                    save_path, "executorch_debug_handle_mapping.json"
                )
                with open(saved_debug_handle_to_ops_mapping_path, "w", encoding="utf-8") as f:
                    f.write(debug_handle_to_ops_mapping_as_json)

        # Overwrite the copied package's spec with the (possibly edited) in-memory spec.
        saved_spec_path = _get_model_spec_path(save_path)
        _save_spec(self._spec, saved_spec_path)


    def get_compiled_model_path(self):
        """
        Return the path of the compiled model backing this object.

        **Important**: The path only remains valid while this Python object is alive.
        To keep the compiled model around, copy it elsewhere.

        Raises
        ------
        Exception
            If the model was never loaded or compiled by the Core ML framework.
        """
        native_model = self.__proxy__
        if native_model is not None:
            return native_model.get_compiled_model_path()
        raise Exception("This model was not loaded or compiled with the Core ML Framework.")


    def get_spec(self):
        """
        Return an independent (deep) copy of the model's protobuf specification.

        Changes made to the returned object do not affect this model.

        Returns
        -------
        model: Model_pb2
            Protobuf specification of the model.

        Examples
        --------
        .. sourcecode:: python

            spec = model.get_spec()

        """
        spec_copy = _deepcopy(self._spec)
        return spec_copy


    def predict(self, data, state: _Optional[MLState] = None):
        """
        Return predictions for the model.

        Parameters
        ----------
        data: dict[str, value] or list[dict[str, value]]
            Input data keyed by input feature name.
            For batch predictions, pass a list of such dictionaries.

            Accepted value types include list, array, numpy.ndarray, tensorflow.Tensor
            and torch.Tensor.

        state : MLState
            Optional state object as returned by ``make_state()``. Only valid for a
            single (non-batch) prediction.

        Returns
        -------
        dict[str, value]
            Predictions as a dictionary where each key is the output feature name.

        list[dict[str, value]]
            For batch prediction, returns a list of the above dictionaries.

        Raises
        ------
        TypeError
            If ``data`` is not a dict or a list of dicts.
        Exception
            If the model cannot be run here: an ``.mlpackage`` on macOS < 12, the
            Core ML framework is unavailable, the spec is too new, or the model has
            a custom layer. If loading failed earlier, that error is re-raised.

        Examples
        --------
        .. sourcecode:: python

            data = {"bedroom": 1.0, "bath": 1.0, "size": 1240}
            predictions = model.predict(data)

            data = [
                {"bedroom": 1.0, "bath": 1.0, "size": 1240},
                {"bedroom": 4.0, "bath": 2.5, "size": 2400},
            ]
            batch_predictions = model.predict(data)

        """
        if self.is_package and _is_macos() and _macos_version() < (12, 0):
            raise Exception(
                "predict() for .mlpackage is not supported in macOS version older than 12.0."
            )
        MLModel._check_predict_data(data)

        if self.__proxy__:
            def prepare_input(input_dict):
                # Validate the keys, then coerce values into types the native proxy accepts.
                self._verify_input_dict(input_dict)
                self._convert_tensor_to_numpy(input_dict)
                # TODO: remove the following call when this is fixed: rdar://92239209
                self._update_float16_multiarray_input_to_float32(input_dict)

            return self._get_predictions(self.__proxy__, prepare_input, data, state)

        # No native proxy: report the most specific reason prediction is unavailable.
        if _macos_version() < (10, 13):
            raise Exception(
                "Model prediction is only supported on macOS version 10.13 or later."
            )
        if not _MLModelProxy:
            raise Exception("Unable to load CoreML.framework. Cannot make predictions.")

        engine_version = _MLModelProxy.maximum_supported_specification_version()
        if engine_version < self._spec.specificationVersion:
            raise Exception(
                "The specification has version "
                + str(self._spec.specificationVersion)
                + " but the Core ML framework version installed only supports Core ML model specification version "
                + str(engine_version)
                + " or older."
            )
        if _has_custom_layer(self._spec):
            raise Exception(
                "This model contains a custom neural network layer, so predict is not supported."
            )
        if self._framework_error:
            raise self._framework_error
        raise Exception("Unable to load CoreML.framework. Cannot make predictions.")


    @staticmethod
    def _check_predict_data(data):
        if type(data) not in (list, dict):
            raise TypeError(
                f'"data" parameter must be either a dict or list of dict, but got {type(data)}.'
            )
        if type(data) == list and not all(map(lambda x: type(x) == dict, data)):
            raise TypeError("\"data\" list must contain only dictionaries")


    @staticmethod
    def _get_predictions(proxy, preprocess_method, data, state):
        if type(data) == dict:
            preprocess_method(data)
            state = None if state is None else state.__proxy__
            return proxy.predict(data, state)
        else:
            assert type(data) == list
            assert state is None, "State can only be used for unbatched predictions"
            for i in data:
                preprocess_method(i)
            return proxy.batchPredict(data)

    def _is_stateful(self) -> bool:
        model_desc = self._spec.description

        # For a single function model, we check if len(state) > 0
        if len(model_desc.functions) == 0:
            return len(model_desc.state) > 0

        # For a multifunction model, we first get the corresponding function description,
        # and check the state field.
        f = list(filter(lambda f: f.name == self.function_name, model_desc.functions))
        return len(f.state) > 0

    def _is_multifunction(self) -> bool:
        return len(self._spec.description.functions) > 0

    def _get_function_description(
        self, function_name: str
    ) -> "_proto.Model_pb2.FunctionDescription":
        f = list(filter(lambda f: f.name == function_name, self._spec.description.functions))

        if len(f) == 0:
            raise ValueError(f"function_name {function_name} not found in the model.")

        assert len(f) == 1, f"Invalid proto: two functions with the same name {function_name}."

        return f[0]

    def make_state(self) -> MLState:
        """
        Returns a new state object, which can be passed to the ``predict`` method.

        Returns
        -------
        state: MLState
            Holds state for an MLModel.

        Raises
        ------
        Exception
            If not running on macOS 15 or newer, or if this model has no
            underlying Core ML model object (i.e. it was not loaded with the
            Core ML framework).

        Notes
        -----
        State functionality is only supported on macOS 15+.

        Examples
        --------
        .. sourcecode:: python

            state = model.make_state()
            predictions = model.predict(x, state)

        See Also
        --------
        predict
        """
        if not _is_macos() or _macos_version() < (15, 0):
            raise Exception("State functionality is only supported on macOS 15+")
        if self.__proxy__ is None:
            raise Exception("This model was not loaded with the Core ML Framework. Cannot get state.")

        # Wrap the framework-level state handle so it can be passed to predict().
        return MLState(self.__proxy__.newState())


    def _input_has_infinite_upper_bound(self) -> bool:
        """Check if any input has infinite upper bound (-1)."""
        for input_spec in self.input_description._fd_spec:
            for size_range in input_spec.type.multiArrayType.shapeRange.sizeRanges:
                if size_range.upperBound == -1:
                    return True
        return False

    def _set_build_info_mil_attributes(self, metadata):
        """
        Record ``metadata`` as the ``buildInfo`` attribute of an ML Program spec.

        The attribute is typed as a dictionary of string to string, and each
        ``(key, value)`` pair of ``metadata`` is appended as one dictionary
        entry. Specs that are not ``mlProgram`` are left untouched.

        Parameters
        ----------
        metadata:
            Mapping whose ``items()`` yield the key/value strings to record
            (presumably ``str`` -> ``str``; the values are stored in a STRING
            tensor).
        """
        # Only ML Program models carry MIL attributes.
        if self._spec.WhichOneof('Type') != "mlProgram":
            return

        build_info = self._spec.mlProgram.attributes["buildInfo"]

        # Element type shared by dictionary keys and values: a STRING tensor.
        string_type = _proto.MIL_pb2.ValueType()
        string_type.tensorType.dataType = _proto.MIL_pb2.DataType.STRING

        # Attribute type: dictionary<string, string>.
        str_to_str_type = _proto.MIL_pb2.ValueType()
        str_to_str_type.dictionaryType.keyType.CopyFrom(string_type)
        str_to_str_type.dictionaryType.valueType.CopyFrom(string_type)
        build_info.type.CopyFrom(str_to_str_type)

        # One typed key/value pair per metadata entry.
        dictionary = build_info.immediateValue.dictionary
        for key, value in metadata.items():
            entry = _proto.MIL_pb2.DictionaryValue.KeyValuePair()
            entry.key.immediateValue.tensor.strings.values.append(key)
            entry.key.type.CopyFrom(string_type)
            entry.value.immediateValue.tensor.strings.values.append(value)
            entry.value.type.CopyFrom(string_type)
            dictionary.values.append(entry)


    def _get_mil_internal(self):
        """
        Get a deep copy of the MIL program object, if available.
        It's available whenever an MLModel object is constructed using
        the unified converter API [``coremltools.convert()``](https://apple.github.io/coremltools/source/coremltools.converters.mil.html#coremltools.converters._converters_entry.convert).

        Returns
        -------
        program: coremltools.converters.mil.Program

        Examples
        --------
        .. sourcecode:: python

            mil_prog = model._get_mil_internal()

        """
        return _deepcopy(self._mil_program)


    def _verify_input_dict(self, input_dict):
        # Check if the input name given by the user is valid.
        # Although this is checked during prediction inside CoreML Framework,
        # we still check it here to return early and
        # return a more verbose error message
        self._verify_input_name_exists(input_dict)

        # verify that the pillow image modes are correct, for image inputs
        self._verify_pil_image_modes(input_dict)


    def _verify_pil_image_modes(self, input_dict):
        """
        Check every image input in ``input_dict`` against its declared color space.

        The check is skipped entirely when Pillow is not importable. For each
        model input of ``imageType``:

        * the provided value must be a ``PIL.Image.Image`` (a missing value
          fails this check too);
        * RGB/BGR inputs need PIL mode ``"RGB"``, GRAYSCALE needs ``"L"``, and
          GRAYSCALE_FLOAT16 needs ``"F"``. Other color spaces are not checked.

        Raises
        ------
        TypeError
            On the first image input that fails one of the checks above.
        """
        if not _HAS_PIL:
            return

        image_type_enum = _proto.FeatureTypes_pb2.ImageFeatureType
        # color space -> (required PIL mode, error message template)
        rgb_requirement = (
            "RGB",
            "RGB/BGR image input, '{}', must be of type PIL.Image.Image with mode=='RGB'",
        )
        mode_requirements = {
            image_type_enum.BGR: rgb_requirement,
            image_type_enum.RGB: rgb_requirement,
            image_type_enum.GRAYSCALE: (
                "L",
                "GRAYSCALE image input, '{}', must be of type PIL.Image.Image with mode=='L'",
            ),
            image_type_enum.GRAYSCALE_FLOAT16: (
                "F",
                "GRAYSCALE_FLOAT16 image input, '{}', must be of type PIL.Image.Image with mode=='F'",
            ),
        }

        for feature in self._spec.description.input:
            if feature.type.WhichOneof("Type") != "imageType":
                continue

            feature_name = feature.name
            provided = input_dict.get(feature_name, None)
            if not isinstance(provided, _PIL_IMAGE.Image):
                template = "Image input, '{}' must be of type PIL.Image.Image in the input dict"
                raise TypeError(template.format(feature_name))

            requirement = mode_requirements.get(feature.type.imageType.colorSpace)
            if requirement is None:
                continue
            expected_mode, template = requirement
            if provided.mode != expected_mode:
                raise TypeError(template.format(feature_name))


    def _verify_input_name_exists(self, input_dict):
        for given_input in input_dict.keys():
            if given_input not in self._model_input_names_set:
                err_msg = "Provided key \"{}\", in the input dict, " \
                          "does not match any of the model input name(s), which are: {}"
                raise KeyError(err_msg.format(given_input, self._model_input_names_set))


    @staticmethod
    def _update_float16_multiarray_input_to_float32(input_data: dict):
        for k, v in input_data.items():
            if isinstance(v, _np.ndarray) and v.dtype == _np.float16:
                input_data[k] = v.astype(_np.float32)

    def _convert_tensor_to_numpy(self, input_dict):
        """
        Convert framework tensors in ``input_dict`` to numpy arrays, in place.

        Only entries whose key names a multi-array model input (an input whose
        ``multiArrayType.dataType`` is not ``INVALID_ARRAY_DATA_TYPE``) are
        converted. All other entries, e.g. image inputs, are left untouched.

        Conversion rules:

        * numpy arrays are kept as is;
        * torch tensors are detached and moved to CPU before ``.numpy()``;
        * TensorFlow tensors are evaluated in a short-lived TF1-compat session;
        * anything else goes through ``numpy.array``.
        """

        def convert(given_input):
            if isinstance(given_input, _numpy.ndarray):
                return given_input
            if _HAS_TORCH and isinstance(given_input, _torch.Tensor):
                # ``.cpu()`` is a no-op for CPU tensors; for tensors on another
                # device it is required before ``.numpy()`` can be called.
                return given_input.detach().cpu().numpy()
            if (_HAS_TF_1 or _HAS_TF_2) and isinstance(given_input, _tf.Tensor):
                # Close the session once the value has been fetched so each
                # conversion does not leak a TF session.
                with _tf.compat.v1.Session() as session:
                    return given_input.eval(session=session)
            return _numpy.array(given_input)

        # Names of model inputs that are multi-arrays; only these get converted.
        multiarray_input_names = set()
        for inp in self._spec.description.input:
            type_value = inp.type.multiArrayType.dataType
            type_name = inp.type.multiArrayType.ArrayDataType.Name(type_value)
            if type_name != "INVALID_ARRAY_DATA_TYPE":
                multiarray_input_names.add(inp.name)

        # Only existing keys are reassigned, so mutating during iteration is safe.
        for given_input_name, given_input in input_dict.items():
            if given_input_name in multiarray_input_names:
                input_dict[given_input_name] = convert(given_input)

    @classmethod
    def get_available_compute_devices(cls) -> _List[_MLComputeDevice]:
        """
        The list of available compute devices for CoreML.

        Use the method to get the list of compute devices that MLModel's predict method can use.

        Some compute devices on the hardware are exclusive to the domain ML frameworks such as Vision and SoundAnalysis and
        not available to Core ML framework. See also ``MLComputeDevice.get_all_compute_devices()``.

        Returns
        -------
        The list of compute devices MLModel's predict method can use.

        Raises
        ------
        Exception
            If the Core ML Python bindings (``_MLModelProxy``) could not be loaded.

        Examples
        --------
        .. sourcecode:: python

            compute_devices = coremltools.MLModel.get_available_compute_devices()

        """
        # ``_MLModelProxy`` is None when the native library failed to import.
        # Fail with a clear message instead of an AttributeError on None.
        if _MLModelProxy is None:
            raise Exception(
                "Unable to load CoreML.framework. Cannot get available compute devices."
            )
        return _MLModelProxy.get_available_compute_devices()

    @property
    def load_duration_in_nano_seconds(self) -> _Optional[int]:
        """
        Retrieves the duration of the model loading process in nanoseconds.

        Notes
        -----
        Calculates the time elapsed during the model loading process, specifically
        measuring the execution time of ``[MLModel loadContentsOfURL:configuration:error:]`` method
        of the Core ML framework.

        Returns
        -------
        Optional[int]:
            The duration of the model loading process in nanoseconds.
            Returns None if duration is not available.
        """

        return self.__proxy__.get_load_duration_in_nano_seconds()

    @property
    def last_predict_duration_in_nano_seconds(self) -> _Optional[int]:
        """
        Retrieves the duration of the last predict operation in nanoseconds.
        This method returns the time taken for the most recent prediction made by
        the model, measured in nanoseconds.

        Notes
        -----
        Calculates the time elapsed during the model predict call, specifically
        measuring the execution time of ``[MLModel predictionFromFeatures:error:]``
        or ``[MLModel predictionFromBatch:error:]`` method of the Core ML framework.

        Returns
        -------
        Optional[int]:
            The duration of the last prediction operation in nanoseconds.
            Returns None if no prediction has been made yet.
        """

        return self.__proxy__.get_last_predict_duration_in_nano_seconds()
